"""
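Subset the patent dictionary and the PatentsBERTa embedding matrix to the
patents that appear in KPSS.

Input:
    1. KPSS_2020_public.csv (raw data)
    2. patentsberta_embedding_matrix.npy (generated by append_raw.py)
    3. patent_dict.csv (patent dictionary: idx, patent_id, patent_year)
Output:
    1. patent_dict_kpss.csv
    2. patentsberta_embedding_matrix_kpss.npy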
Date: 2024/05/25
"""

import pandas as pd
import numpy as np
import os

# All relative paths below are resolved against this data root.
os.chdir(r'/Volumes/Zihao_SSD2/PatentsView/')

# Load the full patent dictionary; columns: [idx, patent_id, patent_year].
# `idx` is kept (as `old_idx`) because it is used further down to pick rows
# out of the embedding matrix.
df = pd.read_csv("patentsberta/patent_dict.csv")

# Load only the patent-number column of KPSS.
df_kpss = pd.read_csv("rawdata/KPSS_2020_public.csv", usecols=["patent_num"], low_memory=False)
print("num patents in kpss: ", len(df_kpss))  # raw row count, before de-duplication

# Align the key name with the patent dictionary and make the key unique:
# a repeated patent number would otherwise multiply rows in the inner merge
# and, in turn, duplicate rows in the KPSS embedding matrix.
# NOTE(review): the merge assumes `patent_id` parses to the same dtype in both
# CSVs (KPSS patent numbers look numeric) -- confirm against patent_dict.csv.
df_kpss = df_kpss.rename(columns={"patent_num": "patent_id"}).drop_duplicates()

# Keep only the patents present in both tables.
df_merged = df.merge(df_kpss, on="patent_id", how="inner")

# Preserve the original row pointer as `old_idx` and expose the fresh
# 0..n-1 position in the filtered table as the new `idx` column.
df_merged = (
    df_merged.rename(columns={"idx": "old_idx"})
    .reset_index()
    .rename(columns={"index": "idx"})
)
print("num patents in merged: ", len(df_merged))
df_merged.to_csv("patentsberta/patent_dict_kpss.csv", index=False)
print("Patent dictionary saved.")

# Open the full embedding matrix as a read-only memory map instead of reading
# it into RAM: only the rows selected below are actually copied into memory.
emb_array = np.load("patentsberta/patentsberta_embedding_matrix.npy", mmap_mode="r")
print(emb_array.shape)

# `old_idx` holds each kept patent's row position in the full matrix, listed
# in the same order as the rows of the filtered patent dictionary.
rows_to_keep = df_merged["old_idx"].to_numpy()

# Integer-array (advanced) indexing copies the selected rows into a new
# in-memory array, so the result no longer depends on the memory map.
emb_array_kpss = emb_array[rows_to_keep]
print(emb_array_kpss.shape)

save_name = "patentsberta/patentsberta_embedding_matrix_kpss.npy"
np.save(save_name, emb_array_kpss)
print("Patent embeddings saved.")
